package fun.ticsmyc.rpc.server.transport.bio;

import fun.ticsmyc.rpc.common.entity.RpcRequest;
import fun.ticsmyc.rpc.common.entity.RpcResponse;
import fun.ticsmyc.rpc.common.enumeration.ResponseCode;
import fun.ticsmyc.rpc.common.enumeration.RpcError;
import fun.ticsmyc.rpc.common.exception.RpcException;
import fun.ticsmyc.rpc.common.serializer.Serializer;
import fun.ticsmyc.rpc.common.serializer.Serializers;
import fun.ticsmyc.rpc.server.handler.RpcRequestHandler;
import fun.ticsmyc.rpc.server.provider.ServiceProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;

/**
 * Handles one client connection after an RPC call request has been accepted.
 *
 * <p>Both the request and the response use the same frame layout:
 * <pre>
 *   4 bytes  magic number (big-endian int)
 *   1 byte   protocol version
 *   1 byte   serializer id
 *   4 bytes  payload length (big-endian int)
 *   N bytes  serialized payload
 * </pre>
 * The request payload is decoded with the serializer named in the request
 * header; the response payload is encoded with the serializer passed to the
 * constructor.
 *
 * @author Ticsmyc
 * @date 2020-10-23 10:10
 */
public class RequestHandler implements Runnable {
    private static final int MAGIC_NUMBER = 0xCAFEBABE;
    private static final byte CURRENT_PROTOCOL_VERSION = 1;
    private static final Logger logger = LoggerFactory.getLogger(RequestHandler.class);

    private final Socket socket;

    private final ServiceProvider registry;
    private final Serializer serializer;


    /**
     * @param socket     the accepted client connection; it is closed when {@link #run()} finishes
     * @param registry   service provider passed to {@code RpcRequestHandler.invokeMethod}
     * @param serializer serializer used to encode the response
     */
    public RequestHandler(Socket socket, ServiceProvider registry, Serializer serializer){
        this.socket = socket;
        this.registry = registry;
        this.serializer = serializer;
    }

    /**
     * Reads one request frame, invokes the requested method and writes one
     * response frame. Any failure is logged; nothing is rethrown to the caller.
     */
    @Override
    public void run() {
        // try-with-resources closes the streams and the socket even when
        // obtaining one of the streams fails.
        try (Socket client = socket;
             InputStream inputStream = client.getInputStream();
             OutputStream outputStream = client.getOutputStream()) {

            // Read the request: per the protocol the client sends an RpcRequest.
            RpcRequest rpcRequest = getRpcRequest(inputStream);

            // Execute the RPC request. An RpcException raised during invocation
            // is reported to the client as METHOD_NOT_FOUND.
            RpcResponse<Object> invokeResult;
            try {
                invokeResult = RpcResponse.success(RpcRequestHandler.invokeMethod(registry, rpcRequest), rpcRequest.getRequestId());
            } catch (RpcException e) {
                invokeResult = RpcResponse.fail(ResponseCode.METHOD_NOT_FOUND, rpcRequest.getRequestId());
            }

            // Send the result back.
            returnResult(outputStream, invokeResult);

        } catch (Exception e) {
            // Boundary of the worker task: log instead of letting the exception escape.
            logger.error("Failed to handle RPC request", e);
        }
    }

    /**
     * Serializes the response and writes it as one protocol frame.
     *
     * @param outputStream stream to the client
     * @param invokeResult response to send
     * @throws IOException if writing to the stream fails
     */
    private void returnResult(OutputStream outputStream, RpcResponse<Object> invokeResult) throws IOException {
        byte[] response = serializer.serialize(invokeResult);
        outputStream.write(intToBytes(MAGIC_NUMBER));
        outputStream.write(new byte[]{CURRENT_PROTOCOL_VERSION});
        outputStream.write(new byte[]{serializer.getId()});
        outputStream.write(intToBytes(response.length));
        outputStream.write(response);
        outputStream.flush();
    }

    /**
     * Reads and decodes one request frame.
     *
     * @param inputStream stream from the client
     * @return the decoded request
     * @throws IOException  if the stream fails or ends before the frame is complete
     * @throws RpcException if the magic number, protocol version, serializer id
     *                      or payload length is not acceptable
     */
    private RpcRequest getRpcRequest(InputStream inputStream) throws IOException {
        // Magic number
        byte[] numberBytes = new byte[4];
        readFully(inputStream, numberBytes);
        int magic = byteToInt(numberBytes);
        if(magic != MAGIC_NUMBER){
            logger.error("不识别的协议包,{}",magic);
            throw new RpcException(RpcError.UNKNOW_PROTOCOL);
        }

        // Protocol version: versions newer than ours cannot be parsed.
        byte[] protocolVersion = new byte[1];
        readFully(inputStream, protocolVersion);
        if(protocolVersion[0] > CURRENT_PROTOCOL_VERSION){
            logger.error("不能解析的协议版本:{}",protocolVersion[0]);
            throw new RpcException(RpcError.UNKNOW_PROTOCOL);
        }

        // Serializer chosen by the client for the request payload
        byte[] serializerCode = new byte[1];
        readFully(inputStream, serializerCode);
        Serializer requestSerializer = Serializers.getSerializerByCode(serializerCode[0]);
        if(requestSerializer == null){
            logger.error("不能识别的序列化器,{}",serializerCode[0]);
            throw new RpcException(RpcError.UNKNOWN_SERIALIZER);
        }

        // Payload length and payload
        readFully(inputStream, numberBytes);
        int length = byteToInt(numberBytes);
        if(length < 0){
            // A negative length would otherwise fail with NegativeArraySizeException.
            logger.error("Invalid payload length: {}", length);
            throw new RpcException(RpcError.UNKNOW_PROTOCOL);
        }
        byte[] bytes = new byte[length];
        readFully(inputStream, bytes);
        return (RpcRequest)requestSerializer.deserialize(bytes, RpcRequest.class);
    }

    /**
     * Fills {@code buffer} completely from the stream. A single
     * {@code InputStream.read(byte[])} call may return fewer bytes than
     * requested, so the read is repeated until the buffer is full.
     *
     * @throws IOException if the stream ends before the buffer is full
     */
    private static void readFully(InputStream inputStream, byte[] buffer) throws IOException {
        int offset = 0;
        while (offset < buffer.length) {
            int count = inputStream.read(buffer, offset, buffer.length - offset);
            if (count < 0) {
                throw new IOException("Unexpected end of stream: expected "
                        + buffer.length + " bytes but got " + offset);
            }
            offset += count;
        }
    }

    /** Decodes the first four bytes of {@code src} as a big-endian int. */
    private static int byteToInt(byte[] src) {
        return ((src[0] & 0xFF) << 24)
                | ((src[1] & 0xFF) << 16)
                | ((src[2] & 0xFF) << 8)
                | (src[3] & 0xFF);
    }

    /** Encodes {@code value} as four big-endian bytes. */
    private static byte[] intToBytes(int value) {
        byte[] src = new byte[4];
        src[0] = (byte) ((value >> 24) & 0xFF);
        src[1] = (byte) ((value >> 16) & 0xFF);
        src[2] = (byte) ((value >> 8) & 0xFF);
        src[3] = (byte) (value & 0xFF);
        return src;
    }
}
